﻿#pragma kernel UpdateSkinConstraints
#pragma kernel UpdateClothMesh

#include "MathUtils.cginc"

// A single weighted influence, as stored in the 'influences' buffer.
struct Influence
{
    // Index of the influencing element, relative to a per-renderer base:
    // UpdateSkinConstraints adds it to skel.firstBone / skin.firstBoneBindPose,
    // UpdateClothMesh adds it to particleOffsets[renderer] / skin.firstParticleBindPose.
    int index;

    // Factor applied to this influence's contribution before it is accumulated.
    float weight;
};

// Per-skinmap base offsets into the shared influence / bind pose buffers.
struct SkinmapData
{
    // base offset into 'influences' used by UpdateClothMesh.
    int firstInfluence;
    // base offset into 'influenceOffsets' used by UpdateClothMesh (indexed per vertex).
    int firstInfNumber;
    // base offset into 'particleBindMatrices'.
    int firstParticleBindPose;
    
    // base offset into 'influences' used by UpdateSkinConstraints.
    int firstSkinWeight;
    // base offset into 'influenceOffsets' used by UpdateSkinConstraints (indexed per particle).
    int firstSkinWeightNumber;
    // base offset into 'boneBindMatrices'.
    int firstBoneBindPose;

    // not read by the kernels in this file; presumably the amount of bind poses in the skinmap - TODO confirm against CPU side.
    int bindPoseCount;
};

// Per-skeleton range of entries in the bonePos / boneRot / boneScl buffers.
struct SkeletonData
{
    // base offset added to an influence's index to address bonePos / boneRot / boneScl.
    int firstBone;
    // UpdateSkinConstraints skips the constraint entirely when this is <= 0.
    int boneCount;
};

// Per-mesh ranges in the batched source mesh buffers.
struct MeshData
{
    // base offset added to a vertex's mesh-local index to address positions / normals / tangents.
    int firstVertex;
    // not read by the kernels in this file.
    int vertexCount;

    // triangle range: not read by the kernels in this file.
    int firstTriangle;
    int triangleCount;
};

// Maps a batch-level particle slot to the index used to address the per-particle
// buffers below (rest/renderable positions and orientations, colors).
StructuredBuffer<int> particleIndices;
StructuredBuffer<int> rendererIndices; // for each vertex/particle, index of its renderer.

// Per-particle state read by UpdateClothMesh to deform mesh vertices.
StructuredBuffer<float4> renderablePositions;
StructuredBuffer<quaternion> renderableOrientations;
StructuredBuffer<float4> colors;

// Per-particle rest state read by UpdateSkinConstraints.
StructuredBuffer<float4> restPositions;
StructuredBuffer<quaternion> restOrientations;

// for each renderer, value added to a particle's renderer-local index to get its
// slot in skinConstraintPoints / skinConstraintNormals.
StructuredBuffer<int> skinConstraintOffsets;

StructuredBuffer<int> skinmapIndices; // for each renderer, index of its skinmap.
StructuredBuffer<int> meshIndices; // for each renderer, index of its mesh.
StructuredBuffer<int> skeletonIndices; // for each renderer, index of its skeleton.

StructuredBuffer<int> particleOffsets; // for each renderer, index of its first particle in the batch.
StructuredBuffer<int> vertexOffsets;  // for each renderer, index of its first vertex in the batch.

StructuredBuffer<SkinmapData> skinData;
StructuredBuffer<Influence> influences;
// consecutive entries [n] and [n+1] delimit the range of influences of element n
// (range values are relative to the skinmap's firstInfluence / firstSkinWeight).
StructuredBuffer<int> influenceOffsets;
StructuredBuffer<float4x4> particleBindMatrices;
StructuredBuffer<float4x4> boneBindMatrices;

// Per-bone values passed to TRS() in UpdateSkinConstraints.
StructuredBuffer<SkeletonData> skeletonData;
StructuredBuffer<float3> bonePos;
StructuredBuffer<quaternion> boneRot;
StructuredBuffer<float3> boneScl;

// Source mesh attributes, addressed with MeshData.firstVertex + mesh-local vertex index.
StructuredBuffer<MeshData> meshData;
StructuredBuffer<float3> positions;
StructuredBuffer<float3> normals;
StructuredBuffer<float4> tangents;

// Outputs of UpdateSkinConstraints.
RWStructuredBuffer<float4> skinConstraintPoints;
RWStructuredBuffer<float4> skinConstraintNormals;
// Output of UpdateClothMesh: 14 32-bit words per vertex, written as
// position (3), normal (3), tangent (4), color (4).
RWByteAddressBuffer vertices;

// Variables set from the CPU
uint vertexCount;      // threads with id.x >= vertexCount exit UpdateClothMesh immediately.
uint constraintCount;  // threads with id.x >= constraintCount exit UpdateSkinConstraints immediately.
float4x4 world2Solver; // left-multiplied onto each bone transform in UpdateSkinConstraints.

// For each skin constraint, blends the particle's rest position and rest normal
// direction through all of its bone influences and stores the weighted sums in
// skinConstraintPoints / skinConstraintNormals. One thread per constraint.
[numthreads(128, 1, 1)]
void UpdateSkinConstraints (uint3 id : SV_DispatchThreadID)
{
    uint i = id.x;

    // threads past the end of the constraint list have nothing to do:
    if (i >= constraintCount)
        return;

    int renderer = rendererIndices[i];

    // skinmap and skeleton used by this renderer:
    SkinmapData skin = skinData[skinmapIndices[renderer]];
    SkeletonData skel = skeletonData[skeletonIndices[renderer]];

    // no bones to skin against, leave the constraint untouched:
    if (skel.boneCount <= 0)
        return;

    // index of this particle relative to the renderer's first particle:
    int localParticle = i - particleOffsets[renderer];

    // influences of this particle live in the range [firstInf, lastInf):
    int offsetSlot = skin.firstSkinWeightNumber + localParticle;
    int firstInf = influenceOffsets[offsetSlot];
    int lastInf = influenceOffsets[offsetSlot + 1];

    // rest-state inputs are the same for every influence, so fetch them once.
    // the normal is the +Z axis passed through rotate_vector() with the rest orientation.
    int particle = particleIndices[i];
    float4 restPoint = float4(restPositions[particle].xyz, 1);
    float4 restNormal = float4(rotate_vector(restOrientations[particle], float3(0, 0, 1)), 0);

    float4 skinnedPoint = FLOAT4_ZERO;
    float4 skinnedNormal = FLOAT4_ZERO;

    for (int k = firstInf; k < lastInf; ++k)
    {
        Influence inf = influences[skin.firstSkinWeight + k];
        int bone = skel.firstBone + inf.index;

        // full transform for this influence: world2Solver * TRS(bone) * bind matrix.
        float4x4 bind = boneBindMatrices[skin.firstBoneBindPose + inf.index];
        float4x4 deform = TRS(bonePos[bone], boneRot[bone], boneScl[bone]);
        float4x4 toSolver = mul(world2Solver, mul(deform, bind));

        // accumulate weighted contributions (w components stay at zero):
        skinnedPoint.xyz += mul(toSolver, restPoint).xyz * inf.weight;
        skinnedNormal.xyz += mul(toSolver, restNormal).xyz * inf.weight;
    }

    // write results to this particle's slot in the constraint buffers:
    int constraint = skinConstraintOffsets[renderer] + localParticle;
    skinConstraintPoints[constraint] = skinnedPoint;
    skinConstraintNormals[constraint] = skinnedNormal;
}

// For each output vertex, blends the source mesh position / normal / tangent
// through all of its particle influences, blends the particle colors with the
// same weights, and writes the result to the 'vertices' byte buffer.
// One thread per vertex.
[numthreads(128, 1, 1)]
void UpdateClothMesh (uint3 id : SV_DispatchThreadID)
{
    uint i = id.x;

    // threads past the end of the vertex list have nothing to do:
    if (i >= vertexCount)
        return;

    int renderer = rendererIndices[i];

    // skinmap and mesh used by this renderer:
    SkinmapData skin = skinData[skinmapIndices[renderer]];
    MeshData mesh = meshData[meshIndices[renderer]];

    // index of the vertex within its own mesh, and within the batched source buffers:
    int localVertex = i - vertexOffsets[renderer];
    int srcVertex = mesh.firstVertex + localVertex;

    // influences of this vertex live in the range [firstInf, lastInf):
    int offsetSlot = skin.firstInfNumber + localVertex;
    int firstInf = influenceOffsets[offsetSlot];
    int lastInf = influenceOffsets[offsetSlot + 1];

    // per-vertex inputs do not change between influences, so fetch them once:
    int firstParticle = particleOffsets[renderer];
    float4 srcPosition = float4(positions[srcVertex], 1);
    float4 srcNormal = float4(normals[srcVertex], 0);
    float4 srcTangent = tangents[srcVertex];

    float3 position = float3(0,0,0);
    float3 normal = float3(0,0,0);
    float4 tangent = FLOAT4_ZERO;
    float4 color = FLOAT4_ZERO;

    for (int k = firstInf; k < lastInf; ++k)
    {
        Influence inf = influences[skin.firstInfluence + k];

        // particle behind this influence:
        int p = particleIndices[firstParticle + inf.index];

        // particle transform (translation * orientation matrix) applied on top of its bind matrix:
        float4x4 deform = mul(m_translate(FLOAT4X4_IDENTITY, renderablePositions[p].xyz), q_toMatrix(renderableOrientations[p]));
        float4x4 trfm = mul(deform, particleBindMatrices[skin.firstParticleBindPose + inf.index]);

        // accumulate weighted vertex attributes; tangent.w is carried over unchanged by the transform:
        position += mul(trfm, srcPosition).xyz * inf.weight;
        normal += mul(trfm, srcNormal).xyz * inf.weight;
        tangent += float4(mul(trfm, float4(srcTangent.xyz, 0)).xyz, srcTangent.w) * inf.weight;
        color += colors[p] * inf.weight;
    }

    // each vertex takes 14 32-bit words in the output: position(3), normal(3), tangent(4), color(4).
    // word offsets are shifted left by 2 to turn them into byte addresses.
    uint firstWord = i * 14;
    vertices.Store3(firstWord << 2, asuint(position));
    vertices.Store3((firstWord + 3) << 2, asuint(normal));
    vertices.Store4((firstWord + 6) << 2, asuint(tangent));
    vertices.Store4((firstWord + 10) << 2, asuint(color));
}